查看原文
其他

为节约而生:从标准Attention到稀疏Attention

苏剑林 PaperWeekly 2022-03-17


作者丨苏剑林

单位丨追一科技

研究方向丨NLP,神经网络

个人主页丨kexue.fm


如今 NLP 领域,Attention 大行其道,当然也不止 NLP,在 CV 领域 Attention 也占有一席之地(Non Local、SAGAN 等)。在 18 年初一文读懂「Attention is All You Need」| 附代码实现一文中,我们就已经讨论过 Attention 机制,Attention 的核心在于 Q,K,V 三个向量序列的交互和融合,其中 Q,K 的交互给出了两两向量之间的某种相关度(权重),而最后的输出序列则是把 V 按照权重求和得到的。 


显然,众多 NLP & CV 的成果已经充分肯定了 Attention 的有效性。本文我们将会介绍 Attention 的一些变体,这些变体的共同特点是——“为节约而生”——既节约时间,也节约显存。


背景简述


Attention is All You Need 一文讨论的我们称之为“乘性 Attention”,目前用得比较广泛的也就是这种 Attention: 



另外还有加性 Attention,但加性 Attention 并行不大容易实现(或者实现起来占用显存多),所以一般只用来将变长向量序列编码为固定长度的向量(取代简单的 Pooling),而很少用来做序列到序列的编码。


而在乘性 Attention 中,用得最广泛的当数 Self Attention 了,这种情况下 Q,K,V 都是同一个 X 经过线性变换之后的结果,这样一来输出结果就是跟 X 一样长的向量序列,并且能够直接捕捉X中任意两个向量的关联,而且易于并行,这都是 Self Attention 的优点。 


然而,从理论上来讲,Self Attention 的计算时间和显存占用量都是级别的(n 是序列长度),这就意味着如果序列长度变成原来的 2 倍,显存占用量就是原来的 4 倍,计算时间也是原来的 4 倍。当然,假设并行核心数足够多的情况下,计算时间未必会增加到原来的 4 倍,但是显存的 4 倍却是实实在在的,无可避免,这也是微调 Bert 的时候时不时就来个 OOM 的原因了。


稀疏Attention


我们说 Self Attention 是的,那是因为它要对序列中的任意两个向量都要计算相关度,得到一个大小的相关度矩阵:


 标准Self Attention的注意力矩阵(左)和关联图示(右)


在上图中,左边显示了注意力矩阵,右变显示了关联性,这表明每个元素都跟序列内所有元素有关联。 


所以,如果要节省显存,加快计算速度,那么一个基本的思路就是减少关联性的计算,也就是认为每个元素只跟序列内的一部分元素相关,这就是稀疏 Attention 的基本原理。


本文所要介绍的稀疏 Attention,源于 OpenAI 的论文 Generating Long Sequences with Sparse Transformers,但没有按照原论文的方式来介绍,而是用一种笔者认为更加自然的思路来介绍。 


Atrous Self Attention 


第一个要引入的概念是 Atrous Self Attention,中文可以称之为“膨胀自注意力”、“空洞自注意力”、“带孔自注意力”等。这个名称跟后面的 Local Self Attention 一样,都是笔者根据它的特性自行命名的,原论文 Generating Long Sequences with Sparse Transformers 并没有出现过这两个概念,但我认为将它们单独引出来是有意义的。 


很显然,Atrous Self Attention 就是启发于“膨胀卷积(Atrous Convolution)”,如下右图所示,它对相关性进行了约束,强行要求每个元素只跟它相对距离为 k,2k,3k,…的元素关联,其中 k>1 是预先设定的超参数。从下左的注意力矩阵看,就是强行要求相对距离不是k的倍数的注意力为 0(白色代表 0):


 Atrous Self Attention的注意力矩阵(左)和关联图示(右)


由于现在计算注意力是“跳着”来了,所以实际上每个元素只跟大约 n/k 个元素算相关性,这样一来理想情况下运行效率和显存占用都变成了,也就是说能直接降低到原来的 1/k。 


Local Self Attention


另一个要引入的过渡概念是 Local Self Attention,中文可称之为“局部自注意力”。其实自注意力机制在 CV 领域统称为“Non Local”,而显然 Local Self Attention 则要放弃全局关联,重新引入局部关联。具体来说也很简单,就是约束每个元素只与前后 k 个元素以及自身有关联,如下图所示:


 Local Self Attention的注意力矩阵(左)和关联图示(右)


从注意力矩阵来看,就是相对距离超过 k 的注意力都直接设为 0。 


其实 Local Self Attention 就跟普通卷积很像了,都是保留了一个 2k+1 大小的窗口,然后在窗口内进行一些运算,不同的是普通卷积是把窗口展平然后接一个全连接层得到输出,而现在是窗口内通过注意力来加权平均得到输出。


对于 Local Self Attention 来说,每个元素只跟 2k+1 个元素算相关性,这样一来理想情况下运行效率和显存占用都变成了 𝒪((2k+1)n)∼𝒪(kn) 了,也就是说随着 n 而线性增长,这是一个很理想的性质——当然也直接牺牲了长程关联性。 


Sparse Self Attention 


到此,就可以很自然地引入 OpenAI 的 Sparse Self Attention 了。我们留意到, Atrous Self Attention 是带有一些洞的,而 Local Self Attention 正好填补了这些洞,所以一个简单的方式就是将 Local Self Attention 和 Atrous Self Attention 交替使用,两者累积起来,理论上也可以学习到全局关联性,也省了显存。


简单画个草图就可以知道,假如第一层用 Local Self Attention 的话,那么输出的每个向量都融合了局部的几个输入向量,然后第二层用 Atrous Self Attention,虽然它是跳着来,但是因为第一层的输出融合了局部的输入向量,所以第二层的输出理论上可以跟任意的输入向量相关,也就是说实现了长程关联。


但是 OpenAI 没有这样做,它直接将两个 Atrous Self Attention 和 Local Self Attention 合并为一个,如下图:


 Sparse Self Attention的注意力矩阵(左)和关联图示(右)


从注意力矩阵上看就很容易理解了,就是除了相对距离不超过 k 的、相对距离为 k,2k,3k,… 的注意力都设为 0,这样一来 Attention 就具有“局部紧密相关和远程稀疏相关”的特性,这对很多任务来说可能是一个不错的先验,因为真正需要密集的长程关联的任务事实上是很少的。


代码实现


上面的 Atrous Self Attention、Local Self Attention、Sparse Self Attention 都算是稀疏 Attention,直观上来看就是注意力矩阵变得很稀疏了。那怎么实现它们呢?如果直接在注意力矩阵中对为零的部分进行 mask 的话,那在数学上(功能上)是没有问题的,但这样做并不能提速,也不能省显存。 


官方实现


OpenAI 也开源了自己的实现,位于:


https://github.com/openai/sparse_attention 


这是基于 tensorflow 的,还用到了它们自己的一个稀疏矩阵库 blocksparse。不过这玩意似乎封装得很奇怪,我不知道怎么将它迁移到 Keras,而且它用了很多 Python 3 的特性,不能直接用于 Python 2。如果用Python 3和纯Tensorflow的朋友可以试试。 


还有一个问题是 OpenAI 原论文主要是用稀疏 Attention 来生成超长序列,所以它在论文中和代码中都把注意力矩阵的所有上三角部分都 mask 了(避免用到未来信息),但未必所有用到稀疏 Attention 的都是生成场景,而且对于基本概念的介绍来说,这是不必要的,这也是笔者不按原论文的思路来介绍的原因之一。 


Keras实现


对于 Keras,笔者根据自己构思的写法实现了上述三种稀疏 Attention,并且和原来写过的 Attention 代码统一规范化了一下,还是放在原来的位置: 


https://github.com/bojone/attention/blob/master/attention_keras.py 


经过实验,发现在笔者的写法中,这三种稀疏 Attention 相比全量 Attention 确实能节省一些内存,但遗憾的是,除了 Atrous Self Attention 外,剩下两种 Attention 的实现都不能提速,反而降低了一点速度,这是因为实现过程中没有充分利用稀疏性所致的,而 OpenAI 的 blocksparse 则是经过高度优化,而且是直接写的 CUDA 代码,这没法比。但不管速度如何,三种稀疏 Attention 功能上应该是没毛病的。


代码实现


也没什么好总结的了,就介绍并实现了三种稀疏 Attention。除了省显存外,稀疏的 Attention 应该能够更好地适应一些任务,毕竟大多数任务的关联主要都在局部的,而且是从局部到整体的形式。尤其是最后一个 Sparse Self Attention 所体现的“局部紧密相关和远程稀疏相关”,应当能满足大多数任务的特点,如果有相应任务的读者,不妨试用一下。




点击以下标题查看作者其他文章: 





#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。


▽ 点击 | 阅读原文 | 查看作者博客

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存